tidymodels: A Machine Learning Framework for R

2023-03-22 01:50 | Source: compiled from the web

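When building models, we have been using each model's default hyperparameters. To obtain a robust and accurate model, these hyperparameters need to be tuned. The tidymodels framework provides the tune, rsample, and dials packages, among others, for hyperparameter tuning.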
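As a rough orientation, the sketch below shows what each of the three packages contributes. It is only an illustration under assumptions: `mtcars` is a stand-in data set, and the names `folds`, `spec`, and `tunable_params` are hypothetical and do not come from the article.

```r
library(tidymodels)

# dials: parameter objects describe the range a hyperparameter can take
trees()        # number of trees
learn_rate()   # learning rate; its range is stored on a log10 scale

# rsample: resampling schemes; mtcars is only a stand-in data set (assumption)
set.seed(123)
folds <- vfold_cv(mtcars, v = 5)

# tune: tune() is a placeholder for a value to be chosen later
spec <- boost_tree(trees = tune(), learn_rate = tune()) %>%
  set_engine("xgboost") %>%
  set_mode("regression")

# List the parameters that are marked for tuning
tunable_params <- extract_parameter_set_dials(spec)
tunable_params
```

Printing `tunable_params` lists the arguments that were set to `tune()`, which is a quick check that the specification is marked up as intended before any resampling is run.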

Load packages

library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.0.0 ──
## ✔ broom        1.0.3     ✔ recipes      1.0.5
## ✔ dials        1.1.0     ✔ rsample      1.1.1
## ✔ dplyr        1.1.0     ✔ tibble       3.1.8
## ✔ ggplot2      3.4.1     ✔ tidyr        1.3.0
## ✔ infer        1.0.4     ✔ tune         1.0.1
## ✔ modeldata    1.1.0     ✔ workflows    1.1.3
## ✔ parsnip      1.0.4     ✔ workflowsets 1.0.0
## ✔ purrr        1.0.1     ✔ yardstick    1.1.0

Load data

mpe <- ... %>% drop_na()
## Rows: 456 Columns: 14
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## dbl (14): MPE, Gender, Age, Fever, Cough, ChestPain, WBCPE, LDHS, TPPE, TPPE...
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.

Split the data

set.seed(123)
split <- ...
...
xgb_res %>% collect_metrics()
## # A tibble: 10 × 11
##    trees min_n tree_depth learn_r…¹ loss_r…² .metric .esti…³  mean     n std_err
##    <int> <int>      <int>     <dbl>    <dbl> <chr>   <chr>   <dbl> <int>   <dbl>
##  1   141    34          1   2.12e-3 1.94e- 5 accura… binary  0.738     5  0.0175
##  2   141    34          1   2.12e-3 1.94e- 5 roc_auc binary  0.5       5  0
##  3   403     6          7   2.21e-4 1.48e+ 1 accura… binary  0.754     5  0.0188
##  4   403     6          7   2.21e-4 1.48e+ 1 roc_auc binary  0.784     5  0.0238
##  5   386    17         10   9.02e-9 3.94e- 1 accura… binary  0.708     5  0.0175
##  6   386    17         10   9.02e-9 3.94e- 1 roc_auc binary  0.764     5  0.0310
##  7   461    37         11   8.64e-9 6.47e-10 accura… binary  0.738     5  0.0175
##  8   461    37         11   8.64e-9 6.47e-10 roc_auc binary  0.5       5  0
##  9    31    17         12   5.89e-8 1.17e- 2 accura… binary  0.708     5  0.0175
## 10    31    17         12   5.89e-8 1.17e- 2 roc_auc binary  0.764     5  0.0310
## # … with 1 more variable: .config <chr>, and abbreviated variable names
## #   ¹learn_rate, ²loss_reduction, ³.estimator
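The table covers five candidate settings, each scored by accuracy and roc_auc over five resamples (the `n` column), but the steps between the split and `xgb_res` are not shown. The following is a minimal sketch of how such a result is typically produced with `tune_grid()`. It is not the author's original code. It assumes a stand-in data set (`two_class_dat` from modeldata, with outcome `Class`) in place of the article's `mpe` data. The split proportion, the stratification, the names `train`, `xgb_spec`, `xgb_wf`, and `folds`, and the use of `grid = 5` are all assumptions. The xgboost package must be installed.

```r
library(tidymodels)

# Stand-in data (assumption): two_class_dat from modeldata, not the article's `mpe`
data(two_class_dat, package = "modeldata")

set.seed(123)
# prop and strata are assumptions, not taken from the article
split <- initial_split(two_class_dat, prop = 0.75, strata = Class)
train <- training(split)

# tune() marks the five hyperparameters that appear in the results table
xgb_spec <- boost_tree(
  trees = tune(),
  min_n = tune(),
  tree_depth = tune(),
  learn_rate = tune(),
  loss_reduction = tune()
) %>%
  set_engine("xgboost") %>%
  set_mode("classification")

# Bundle the formula and the model specification
xgb_wf <- workflow() %>%
  add_formula(Class ~ .) %>%
  add_model(xgb_spec)

# 5-fold cross-validation on the training set
folds <- vfold_cv(train, v = 5)

# grid = 5 asks tune to generate five candidate settings automatically
xgb_res <- tune_grid(
  xgb_wf,
  resamples = folds,
  grid = 5,
  metrics = metric_set(accuracy, roc_auc)
)

xgb_res %>% collect_metrics()
```

With five candidates and two metrics, `collect_metrics()` returns ten rows, which matches the shape of the table above. Passing an integer to `grid` leaves the choice of candidate values to tune. A data frame built with dials can be supplied instead when specific values are wanted.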

Select the best parameters

best_xgb <- xgb_res %>% select_best(metric = "accuracy")
best_xgb
## # A tibble: 1 × 6
##   trees min_n tree_depth learn_rate loss_reduction .config
##   <int> <int>      <int>      <dbl>          <dbl> <chr>
## 1   403     6          7   0.000221           14.8 Preprocessor1_Model2

View the models with the highest accuracy

show_best(xgb_res, metric = "accuracy")  # sorted by accuracy, highest first
## # A tibble: 5 × 11
##   trees min_n tree_depth learn_rate loss_r…¹ .metric .esti…²  mean     n std_err
##   <int> <int>      <int>      <dbl>    <dbl> <chr>   <chr>   <dbl> <int>   <dbl>
## 1   403     6          7    2.21e-4 1.48e+ 1 accura… binary  0.754     5  0.0188
## 2   141    34          1    2.12e-3 1.94e- 5 accura… binary  0.738     5  0.0175
## 3   461    37         11    8.64e-9 6.47e-10 accura… binary  0.738     5  0.0175
## 4   386    17         10    9.02e-9 3.94e- 1 accura… binary  0.708     5  0.0175
## 5    31    17         12    5.89e-8 1.17e- 2 accura… binary  0.708     5  0.0175
## # … with 1 more variable: .config <chr>, and abbreviated variable names
## #   ¹loss_reduction, ².estimator

Final model

final_wf <- ... %>% finalize_workflow(best_xgb)
final_wf
## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Formula
## Model: boost_tree()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## factor(MPE) ~ .
##
## ── Model ───────────────────────────────────────────────────────────────────────
## Boosted Tree Model Specification (classification)
##
## Main Arguments:
##   trees = 403
##   min_n = 6
##   tree_depth = 7
##   learn_rate = 0.000220724331930274
##   loss_reduction = 14.8435496067182
##
## Computational engine: xgboost

Performance on the test set

final_fit <- final_wf %>% last_fit(split)
final_fit %>% collect_metrics()
## # A tibble: 2 × 4
##   .metric  .estimator .estimate .config
##   <chr>    <chr>          <dbl> <chr>
## 1 accuracy binary         0.732 Preprocessor1_Model1
## 2 roc_auc  binary         0.729 Preprocessor1_Model1

On the test set, the accuracy is 0.732 and the AUC is 0.729.

# Predictions
xgb_pred <- final_fit %>% collect_predictions()
# ROC curve
xgb_pred %>% roc_curve(`factor(MPE)`, .pred_0) %>% autoplot()
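By default, yardstick treats the first factor level as the event, which is consistent with passing the probability column of the first level (`.pred_0`) to `roc_curve()` above. The sketch below illustrates this on `two_class_example`, a small data set bundled with yardstick. It is a stand-in (an assumption), not the article's predictions, and the names `auc_first`, `auc_second`, and `roc_plot` are hypothetical.

```r
library(tidymodels)

# Stand-in data (assumption): two_class_example has the columns
# truth, Class1, Class2 (class probabilities), and predicted

# Default: the first level ("Class1") is the event, so pass its probability column
auc_first <- roc_auc(two_class_example, truth, Class1)

# Treat the second level as the event instead: pass Class2 and say so explicitly
auc_second <- roc_auc(two_class_example, truth, Class2, event_level = "second")

auc_first
auc_second

# ROC curve as a ggplot object; print roc_plot to draw it
roc_plot <- two_class_example %>%
  roc_curve(truth, Class1) %>%
  autoplot()
```

Both calls describe the same classifier from opposite sides, so the two AUC values agree. Pairing the wrong probability column with the event level would instead report one minus the AUC, which is worth checking whenever a reported AUC falls below 0.5.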



